# -*- coding: utf-8 -*-
# @Time : 2021-11-17 19:11
# @Author : lwb
# @Site : 
# @File : LSTMprocess.py
import torch
from torch.nn.utils.rnn import pad_sequence
# 这个函数也可不用，只需要将他的操作，在dataset中return 数据的时候应该就行了
# collate_fn参数指向一个函数，用于对一个批次的样本进行整理，如将其转换成张量
def collate_fn(examples):
    lenghts=torch.tensor([len(ex[0]) for ex in examples])
    inputs = [torch.tensor(ex[0]) for ex in examples]
    # 输出的目标targets为该批次中全部样例输出结果（0或者1）构成的张量
    targets = torch.tensor([ex[1] for ex in examples], dtype=torch.long)
    inputs = pad_sequence(inputs,batch_first=True)
    return inputs,lenghts, targets